import math
from collections import OrderedDict

import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dm_env import specs

import utils
from agent.ddpg import DDPGAgent


class CIM(nn.Module):
    """Contrastive skill model built around a small observation encoder.

    ``state_net`` is a two-layer MLP (Linear -> ReLU -> Linear) mapping an
    ``obs_dim`` vector to a ``skill_dim`` embedding. ``forward`` pairs the skill
    (query) with the displacement of a transition in that embedding (key).
    """

    def __init__(self, obs_dim, skill_dim, hidden_dim):
        super().__init__()
        self.obs_dim, self.skill_dim = obs_dim, skill_dim

        layers = [
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, skill_dim),
        ]
        self.state_net = nn.Sequential(*layers)

        self.apply(utils.weight_init)

    def forward(self, state, next_state, skill):
        """Return ``(query, key)`` for a batch of transitions.

        ``query`` is ``skill`` itself; ``key`` is
        ``state_net(next_state) - state_net(state)``. Both state tensors must
        have the same number of dimensions (checked with ``assert``).
        """
        assert state.dim() == next_state.dim()
        encode = self.state_net
        start_embedding = encode(state)
        end_embedding = encode(next_state)
        return skill, end_embedding - start_embedding


# Device for the module-level running statistics (RMS) used by the APT reward.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class RMS(object):
    """Running mean / variance tracker updated one batch at a time.

    ``M`` is the running mean and ``S`` the running pooled *variance* along
    dim 0. ``n`` is the effective sample count; it starts at ``epsilon`` so
    the initial prior (``M = 0``, ``S = 1``) carries almost no weight once real
    batches arrive. State tensors are created on the module-level ``device``.
    """

    def __init__(self, epsilon=1e-4, shape=(1,)):
        self.M = torch.zeros(shape, device=device)
        self.S = torch.ones(shape, device=device)
        self.n = epsilon

    def __call__(self, x):
        """Merge batch ``x`` (reduced over dim 0) into the statistics.

        Returns the updated ``(M, S)`` pair, i.e. (mean, variance).
        """
        count = x.size(0)
        total = self.n + count
        batch_mean = torch.mean(x, dim=0)
        batch_var = torch.var(x, dim=0)
        shift = batch_mean - self.M

        mean = self.M + shift * count / total
        variance = (
            self.S * self.n
            + batch_var * count
            + (shift**2) * self.n * count / total
        ) / total

        self.M, self.S = mean, variance
        self.n += count
        return self.M, self.S


class APTArgs:
    """Hyper-parameters consumed by ``compute_apt_reward``.

    Attributes:
        knn_k: number of nearest neighbours taken per row.
        knn_avg: average all k neighbour distances (True) or use only the
            k-th nearest one (False).
        rms: rescale distances by the module-level running statistics.
        knn_clip: offset subtracted from each distance before flooring at 0.
    """

    def __init__(
        self,
        knn_k=16,
        knn_avg=True,
        rms=True,
        knn_clip=0.0005,
    ):
        self.knn_k, self.knn_avg, self.rms, self.knn_clip = (
            knn_k,
            knn_avg,
            rms,
            knn_clip,
        )


# Shared running statistics used by compute_apt_reward to rescale distances.
rms = RMS(epsilon=1e-4, shape=(1,))


def compute_apt_reward(source, target, args):
    """Compute the APT particle-based (k-nearest-neighbour) reward.

    For every row of ``source``, the Euclidean distances to all rows of
    ``target`` are computed and the ``k`` smallest are kept, where
    ``k = min(args.knn_k, target.size(0))``. The distances are optionally
    rescaled with the module-level running statistics ``rms``. They are then
    shifted down by ``args.knn_clip``, floored at zero and mapped through
    ``log(1 + r)``.

    Args:
        source: tensor with ``b1`` rows; each row is flattened to a vector.
        target: tensor with ``b2`` rows, flattened the same way.
        args: config exposing ``knn_k``, ``knn_avg``, ``rms`` and
            ``knn_clip`` (see ``APTArgs``).

    Returns:
        A ``(b1,)`` tensor (mean over the k neighbours) when ``args.knn_avg``
        is true. Otherwise a ``(b1, 1)`` tensor built from the k-th nearest
        neighbour only.
    """
    b1, b2 = source.size(0), target.size(0)
    # topk raises when asked for more neighbours than target has rows (e.g.
    # short episode segments), so cap k at the number of candidates.
    k = min(args.knn_k, b2)

    # (b1, 1, c) - (1, b2, c) -> (b1, b2, c) -> (b1, b2) pairwise L2 distances
    sim_matrix = torch.norm(
        source.reshape(b1, 1, -1) - target.reshape(1, b2, -1),
        dim=-1,
        p=2,
    )
    reward, _ = sim_matrix.topk(k, dim=1, largest=False, sorted=True)  # (b1, k)

    if args.knn_avg:
        # Every neighbour distance is normalised; averaged per row below.
        reward = reward.reshape(-1, 1)  # (b1 * k, 1)
    else:
        # Sorted ascending, so the last column is the k-th nearest neighbour.
        reward = reward[:, -1:]  # (b1, 1)

    if args.rms:
        # NOTE(review): RMS returns (running mean, running variance), so this
        # divides by the variance rather than the standard deviation. Left
        # as-is to preserve the existing reward scale.
        _, moving_std = rms(reward)
        reward = reward / moving_std
    # Floor at zero; clamp keeps the result on reward's own device.
    reward = (reward - args.knn_clip).clamp(min=0.0)

    if args.knn_avg:
        reward = reward.reshape(b1, k).mean(dim=1)  # (b1,)
    return torch.log(reward + 1.0)


class CIMAgent(DDPGAgent):
    """DDPG agent with a contrastive skill model, after Contrastive Intrinsic Control (CIC).

    In reward-free mode, transitions are cut from whole replayed episodes.
    ``CIM`` is optionally trained with a contrastive loss, and the critic is
    trained on an APT k-nearest-neighbour reward computed in the
    skill-projected ``CIM.state_net`` embedding. Otherwise the extrinsic
    reward is used.
    """

    def __init__(
        self,
        update_skill_every_step,
        skill_dim,
        scale,
        project_skill,
        rew_type,
        update_rep,
        temp,
        update_encoder,
        num_init_steps,
        discount,
        **kwargs
    ):
        """Store CIM hyper-parameters, build the DDPG parts, then the CIM model.

        Args:
            update_skill_every_step: period at which ``update_meta`` resamples
                the skill; also the episode-segment length assumed when
                slicing replayed episodes in ``update``.
            skill_dim: skill vector size; forwarded as ``meta_dim``.
            scale: multiplier applied by ``compute_intr_reward``.
            project_skill: stored; not read in this class.
            rew_type: stored; not read in this class.
            update_rep: whether ``update`` trains ``CIM`` in reward-free mode.
            temp: stored; not read in this class.
            update_encoder: when false, encoded observations are detached.
            num_init_steps: number of replayed rows ``regress_meta`` gathers.
            discount: per-step discount, raised to ``nstep`` in ``update``.
            **kwargs: forwarded to ``DDPGAgent``; must contain ``hidden_dim``
                and ``device``.
        """
        self.temp = temp
        self.skill_dim = skill_dim
        self.update_skill_every_step = update_skill_every_step
        self.scale = scale
        self.project_skill = project_skill
        self.rew_type = rew_type
        self.update_rep = update_rep
        self.update_encoder = update_encoder
        self.num_init_steps = num_init_steps
        self.discount = discount
        kwargs["meta_dim"] = self.skill_dim
        self.solved_meta = None

        # The base agent builds actor and critic (and presumably sets the
        # obs_dim / lr attributes read below).
        super().__init__(**kwargs)

        # CIM embeds the observation without the skill part; assumes obs_dim
        # from the base agent includes the skill — TODO confirm.
        self.cim = CIM(self.obs_dim - skill_dim, skill_dim, kwargs["hidden_dim"]).to(
            kwargs["device"]
        )

        # optimizers
        self.cim_optimizer = torch.optim.Adam(self.cim.parameters(), lr=self.lr)

        self.cim.train()

    def get_meta_specs(self):
        """Declare the single float32 ``skill`` meta entry of size ``skill_dim``."""
        return (specs.Array((self.skill_dim,), np.float32, "skill"),)

    def init_meta(self):
        """Return a fresh meta dict holding a ``skill`` vector.

        Reward-free: a skill drawn uniformly from ``[-1, 1]^skill_dim``.
        Otherwise: ``self.solved_meta`` when ``regress_meta`` has set it, else a
        uniform draw rescaled to unit L2 norm.
        """
        if not self.reward_free:
            if self.solved_meta is not None:
                return self.solved_meta
            # Unit-norm skill from a normalised uniform-cube sample. To choose
            # the skill automatically, use the CEM or grid-sweep procedures
            # described in the CIC paper.
            skill = np.random.uniform(-1, 1, self.skill_dim).astype(np.float32)
            skill = skill / (np.linalg.norm(skill) + 1e-8)
        else:
            # NOTE(review): unlike the branch above, this skill is not
            # normalised — confirm that is intended.
            skill = np.random.uniform(-1, 1, self.skill_dim).astype(np.float32)
        meta = OrderedDict()
        meta["skill"] = skill
        return meta

    def update_meta(self, meta, step, time_step):
        """Resample the skill every ``update_skill_every_step`` steps."""
        if step % self.update_skill_every_step == 0:
            return self.init_meta()
        return meta

    def compute_cpc_loss(self, obs, next_obs, skill):
        """Return the per-sample contrastive loss, shape ``(batch,)``.

        ``S[i, j]`` is the dot product of skill ``i`` (query) with transition
        ``j``'s key. Each row therefore scores one skill against every
        transition in the batch. ``-(S[i, i] - logsumexp_j S[i, j])`` is the
        cross-entropy of picking its own transition.

        NOTE(review): ``self.temp`` is not applied to these logits.
        """
        query, key = self.cim.forward(obs, next_obs, skill)
        S = key.mm(query.T).T  # S[i, j] = query_i . key_j
        return -(S.diag() - S.logsumexp(1))

    def update_cim(self, obs, skill, next_obs, step):
        """Take one Adam step on the mean contrastive loss; return metrics."""
        metrics = dict()

        loss = self.compute_cpc_loss(obs, next_obs, skill)
        loss = loss.mean()
        self.cim_optimizer.zero_grad()
        loss.backward()
        self.cim_optimizer.step()

        if self.use_tb or self.use_wandb:
            metrics["cim_loss"] = loss.item()

        return metrics

    def compute_intr_reward(self, obs, skill, next_obs, step):
        """Return the scaled per-sample contrastive loss as a ``(batch, 1)`` reward."""
        with torch.no_grad():
            # compute_cpc_loss returns a single tensor; unpacking it into
            # (loss, logits) raised (or mis-split a batch of two).
            loss = self.compute_cpc_loss(obs, next_obs, skill)

        reward = loss.clone().detach().unsqueeze(-1)

        return reward * self.scale

    @torch.no_grad()
    def compute_apt_reward(self, obs, skills):
        """APT reward per episode segment, concatenated into a ``(N, 1)`` tensor.

        For each episode, the ``state_net`` embeddings of steps
        ``1 + nstep`` onward are projected onto the episode's last skill. The
        projection is floored at zero, and the module-level
        ``compute_apt_reward`` is applied with the segment as both source and
        target.
        """
        args = APTArgs()
        rewards = []
        for ep_obs, ep_skills in zip(obs, skills):
            embedding = self.cim.state_net(ep_obs)
            # [-1:] keeps a leading length-1 axis so the skill broadcasts
            # over time.
            skill = ep_skills[-1:]
            projected = (embedding[1 + self.nstep :] * skill).sum(-1, True).clamp(0)
            # reshape(-1) rather than squeeze(): a 1-step segment must stay
            # 1-D so torch.cat can join it.
            rewards.append(compute_apt_reward(projected, projected, args).reshape(-1))
        return torch.cat(rewards).unsqueeze(-1)  # (N, 1)

    @staticmethod
    def _time_window(x, start, stop_offset):
        """Return ``x[:, start : x.size(1) + stop_offset]`` (slice along dim 1).

        The stop is spelled relative to the length because a literal such as
        ``-nstep`` becomes ``-0`` (an empty slice) when the offset is zero.
        """
        return x[:, start : x.size(1) + stop_offset]

    def update(self, replay_iter, step):
        """Run one agent update every ``update_every_steps`` steps; return metrics.

        Reward-free: the sampled transitions are replaced by windows cut from
        whole replayed episodes. ``CIM`` is optionally trained on them, and
        the critic uses the APT reward. Otherwise the extrinsic reward is
        used.
        """
        metrics = dict()

        if step % self.update_every_steps != 0:
            return metrics

        batch = next(replay_iter)

        obs, action, extr_reward, discount, next_obs, skill = utils.to_torch(
            batch[:-2], self.device
        )

        if self.reward_free:
            episodes, idxs = batch[-2:]
            # Rows per episode: obs at index t (from 1), next obs at
            # t + nstep, action/discount at t + 1.
            window = self.update_skill_every_step - 1 - self.nstep
            n_episodes = int(np.ceil(len(obs) / window))
            batch_size = n_episodes * window
            ep_obs = episodes["observation"][:n_episodes].to(self.device)
            ep_skill = episodes["skill"][:n_episodes].to(self.device)
            ep_act = episodes["action"][:n_episodes].to(self.device)
            ep_discount = episodes["discount"][:n_episodes].to(self.device)
            # Stop indices are computed from each tensor's time length. Negative
            # literals like [:, 1:-nstep] / [:, 2:1 - nstep] would yield empty
            # slices (and a failing reshape) when nstep <= 1.
            obs = self._time_window(ep_obs, 1, -self.nstep).reshape(batch_size, -1)
            next_obs = ep_obs[:, 1 + self.nstep :].reshape(batch_size, -1)
            skill = self._time_window(ep_skill, 1, -self.nstep).reshape(
                batch_size, -1
            )
            action = self._time_window(ep_act, 2, 1 - self.nstep).reshape(
                batch_size, -1
            )
            discount = (
                self._time_window(ep_discount, 2, 1 - self.nstep).reshape(
                    batch_size, -1
                )
                * self.discount**self.nstep
            )

        with torch.no_grad():
            obs = self.aug_and_encode(obs)

            next_obs = self.aug_and_encode(next_obs)

        if self.reward_free:
            if self.update_rep:
                metrics.update(self.update_cim(obs, skill, next_obs, step))

            intr_reward = self.compute_apt_reward(ep_obs, ep_skill)
            if self.use_tb or self.use_wandb:
                metrics["intr_reward"] = intr_reward.mean().item()

            reward = intr_reward
        else:
            reward = extr_reward

        if self.use_tb or self.use_wandb:
            metrics["extr_reward"] = extr_reward.mean().item()
            metrics["batch_reward"] = reward.mean().item()

        if not self.update_encoder:
            obs = obs.detach()
            next_obs = next_obs.detach()

        # extend observations with skill
        obs = torch.cat([obs, skill], dim=1)
        next_obs = torch.cat([next_obs, skill], dim=1)

        # update critic
        metrics.update(
            self.update_critic(
                obs.detach(), action, reward, discount, next_obs.detach(), step
            )
        )

        # update actor
        metrics.update(self.update_actor(obs.detach(), step))

        # update critic target
        utils.soft_update_params(
            self.critic, self.critic_target, self.critic_target_tau
        )

        return metrics

    @torch.no_grad()
    def regress_meta(self, replay_iter, step):
        """Choose the replayed skill with the highest summed episode reward.

        Draws batches until at least ``num_init_steps`` skill rows are
        gathered. For each batch, ``episodes["reward"]`` is squeezed and its
        last axis summed. The skill at the argmax is cached in
        ``self.solved_meta``, which is returned directly on later calls.

        NOTE(review): this pairs ``ep_reward[i]`` with ``ep_skill[i]``, which
        assumes the episodes dict holds one entry per sampled row, in the same
        order — confirm against the replay loader.
        """
        if self.solved_meta is not None:
            return self.solved_meta
        ep_reward = []
        ep_skill = []
        batch_size = 0
        while batch_size < self.num_init_steps:
            batch = next(replay_iter)
            ep_reward.append(batch[-2]["reward"].squeeze().sum(-1))
            ep_skill.append(batch[-3])
            batch_size += batch[-3].size(0)
        ep_reward, ep_skill = torch.cat(ep_reward), torch.cat(ep_skill)

        meta = OrderedDict()
        meta["skill"] = ep_skill[torch.argmax(ep_reward)].cpu().numpy()

        # save for evaluation
        self.solved_meta = meta
        print("solved skill is ", self.solved_meta)
        return meta
